{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "from torch.distributions.multivariate_normal import MultivariateNormal\n",
    "from torch.distributions.mixture_same_family import MixtureSameFamily\n",
    "from torch.distributions.categorical import Categorical\n",
    "\n",
    "%matplotlib widget"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gaussian Mixture Model\n",
    "\n",
    "Suppose we are given the weights and heights of Statsville residents. There are three classes of residents in Statsville: women, men and children. The data is unlabelled, meaning we do not know whether a given (height, weight) instance is associated with a man, a woman or a child.\n",
    "\n",
    "We assume that, considered individually, the distribution of each class is Gaussian. That is, the (height, weight) pairs of men, women and children each follow a 2D Gaussian distribution.\n",
    "\n",
    "Our task is to come up with a model that can identify the category (man or woman or child) from weight and height."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A Gaussian Mixture Model (GMM) is a weighted combination of a fixed number of Gaussian components. GMMs are used to represent phenomena that naturally comprise multiple classes. Each Gaussian component corresponds to a specific class.\n",
    "\n",
    "A $K$-component GMM is typically used to model a phenomenon with $K$ sub-classes, each of which can be modelled by a Gaussian.\n",
    "\n",
    "$$\n",
    "p\\left(\\vec{x}\\right) =\\sum_{k=1}^{K} p\\left(\\vec{x}, Z = k\\right) \n",
    "                       = \\sum_{k=1}^{K} p\\left(Z = k\\right) \\, p\\left(\\vec{x} \\middle\\vert Z= k\\right)\n",
    "                       = \\sum_{k=1}^{K} \\pi_{k} \\; \\mathcal{N}\\left(\\vec{x}; \\, \\vec{\\mu_{k}}, \\mathbf{\\Sigma}_{k}\\right)$$\n",
    "                       \n",
    "The $k^{th}$ component Gaussian, $\\mathcal{N}\\left(X = \\vec{x}; \\, \\vec{\\mu_{k}}, \\mathbf{\\Sigma}_{k}\\right)$, can be interpreted as the conditional probability density, $p\\left(X=\\vec{x} \\middle\\vert Z = k\\right)$, of the data value $\\vec{x}$ occurring, given the $k^{th}$ sub-class.\n",
    "\n",
    "The weight $\\pi_{k}$ can be interpreted as the prior probability of selecting the $k^{th}$ sub-class, i.e., the fraction of the population belonging to the $k^{th}$ sub-class.\n",
    "\n",
    "Let us try and model the unlabelled heights and weights data with a 3 component GMM and see if the learnt clusters correspond to the expected categories."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In real-life scenarios, we will not know the true values of the parameters (means and covariances) of the Gaussian distributions. However, for the sake of understanding, let us suppose we know the true parameters governing the height and weight values of the 3 categories.\n",
    "\n",
    "Specifically: \\\n",
    "$\\vec{\\mu}_{man}=\\begin{bmatrix}\n",
    "  175\\\\70\n",
    "  \\end{bmatrix}$,\n",
    "  $\\mathbf{\\Sigma}_{man} = \\begin{bmatrix}\n",
    "  8 & 10\\\\10 & 25\n",
    "  \\end{bmatrix}$,\n",
    "  $\\mathbf{\\pi}_{man} = 0.4$, \\\n",
    "$\\vec{\\mu}_{woman}=\\begin{bmatrix}\n",
    "  152\\\\55\n",
    "  \\end{bmatrix}$,\n",
    "  $\\mathbf{\\Sigma}_{woman} = \\begin{bmatrix}\n",
    "  8 & 0\\\\0 & 15\n",
    "  \\end{bmatrix}$,\n",
    "  $\\mathbf{\\pi}_{woman} = 0.4$, \\\n",
    "$\\vec{\\mu}_{child}=\\begin{bmatrix}\n",
    "  135\\\\40\n",
    "  \\end{bmatrix}$,\n",
    "  $\\mathbf{\\Sigma}_{child} = \\begin{bmatrix}\n",
    "  5 & 0\\\\0 & 5\n",
    "  \\end{bmatrix}$,\n",
    "  $\\mathbf{\\pi}_{child} = 0.2$\n",
    "  \n",
    "  \n",
    "We will use these parameters to generate an unlabelled sample of data. We will then fit a GMM to this data and verify that the fitted GMM approximately recovers these parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_gmm(pi, mu, sigma):\n",
    "    mix = Categorical(pi)\n",
    "    components = MultivariateNormal(mu, sigma)\n",
    "    # A GMM can be viewed as a mixture of multiple normal distributions (components). The probability of\n",
    "    # each component is defined by a categorical distribution, pi (mix)\n",
    "    gmm = MixtureSameFamily(mix, components)\n",
    "    return gmm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For the purposes of data generation, let us assume we already know the \n",
    "# true values of parameters\n",
    "mu_true = torch.tensor([[175.0, 70.0], \n",
    "                        [152.0, 55.0], \n",
    "                        [135.0, 40.0]])\n",
    "\n",
    "sigma_true = torch.tensor([[[8.0, 10.0], [10.0, 25.0]],\n",
    "                           [[8.0, 0.0], [0.0, 15.0]],\n",
    "                           [[5.0, 0.0], [0.0, 5.0]]])\n",
    "\n",
    "pi_true = torch.tensor([0.4, 0.4, 0.2])\n",
    "\n",
    "data_generator = create_gmm(pi_true, mu_true, sigma_true)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let us assume we have a sample of heights and weights from the residents of Statsville\n",
    "# And our task is to model this data with a GMM that has 3 components.\n",
    "num_samples = 1000\n",
    "\n",
    "# Generate the unlabelled samples from the \"known\" GMM. Note that in real world examples\n",
    "# we will not have a known GMM. Instead we will just get a sample of unlabelled data.\n",
    "X = data_generator.sample([num_samples])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_data(X_list, title, colors=['red']):\n",
    "    fig, ax = plt.subplots()\n",
    "    for i, X in enumerate(X_list):\n",
    "        ax.scatter(X[:, 0], X[:, 1], c=colors[i], s=15, edgecolors='black')\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel(\"Height in cm\")\n",
    "    ax.set_ylabel(\"Weight in kg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "343e5d909fc5496bad4b89b6c6cf8af1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# As we can see clearly, there exist 3 distinct clusters in the underlying data\n",
    "# We would expect these clusters to correspond to the 3 true classes, i.e., men, women and children\n",
    "plot_data([X], \"Data distribution\", colors=[\"orange\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_gmm_components(gmm, num_samples, step):\n",
    "    pi = gmm.mixture_distribution.probs\n",
    "    components = gmm.component_distribution\n",
    "    # Sample num samples from every component distribution\n",
    "    X_components = components.sample([num_samples])\n",
    "    num_component_samples = pi * num_samples\n",
    "\n",
    "    X_list = []\n",
    "    for i in range(pi.shape[0]):\n",
    "        # For every component, pick pi * num samples \n",
    "        X_list.append(X_components[:, i, :][:int(num_component_samples[i])])\n",
    "    plot_data(X_list, colors=[\"red\", \"blue\", \"green\"], title=\"Step {}\".format(step))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We want to model this data with a GMM with 3 classes\n",
    "# Let us start with an arbitrary, hand-picked initialization (Categorical renormalizes pi_init so that it sums to 1)\n",
    "pi_init = torch.tensor([0.33, 0.33, 0.33])\n",
    "mu_init = torch.tensor([[100.0, 30.0], [160, 50.0], [200.0, 100.0]])\n",
    "sigma_init = torch.tensor([[[10.0, 0.0], [0.0, 10.0]],[[10.0, 0.0], [0.0, 10.0]], [[10.0, 0.0], [0.0, 10.0]]])\n",
    "\n",
    "gmm_init = create_gmm(pi_init, mu_init, sigma_init)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Maximum Likelihood Estimation of GMM parameters (GMM Fit)\n",
    "\n",
    "Maximum likelihood estimation of a GMM is done iteratively via Expectation Maximization (EM). We alternately perform the E and the M steps until the increase in log-likelihood is negligible.\n"
   ]
  },
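  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The quantity being maximized can be written out explicitly. The expression below is a sketch of the standard GMM log-likelihood for $n$ independent samples, using the same symbols as above; the tolerance symbol $\\epsilon$ is an assumption introduced here to stand for whatever small threshold is chosen.\n",
    "\n",
    "$$\\log L\\left(\\pi, \\mu, \\Sigma\\right) = \\sum_{i=1}^{n} \\log \\sum_{k=1}^{K} \\pi_{k} \\; \\mathcal{N}\\left(\\vec{x}_{i}; \\, \\vec{\\mu}_{k}, \\mathbf{\\Sigma}_{k}\\right)$$\n",
    "\n",
    "Iteration stops once the increase in $\\log L$ between two successive EM iterations falls below $\\epsilon$."
   ]
  },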
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### E Step\n",
    "In the E step, we compute the posterior probability of each point belonging to each of the clusters:\n",
    "\n",
    "$$\\gamma_{ik} =  \\frac{\\pi_{k} \\, \\mathcal{N}(\\vec{x}_{i}; \\vec{\\mu}_{k}, \\mathbf{\\Sigma}_{k})}{\\sum_{j=1}^{K} \\pi_{j} \\, \\mathcal{N}(\\vec{x}_{i}; \\vec{\\mu}_{j}, \\mathbf{\\Sigma}_{j})} $$"
   ]
  },
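  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For numerical work it can help to see the same posterior in the log domain. The expression below is a sketch obtained simply by taking logarithms of the formula above; the summation index $j$ runs over all components and is kept distinct from the component index $k$.\n",
    "\n",
    "$$\\log \\gamma_{ik} = \\log \\pi_{k} + \\log \\mathcal{N}\\left(\\vec{x}_{i}; \\, \\vec{\\mu}_{k}, \\mathbf{\\Sigma}_{k}\\right) - \\log \\sum_{j=1}^{K} \\exp\\left( \\log \\pi_{j} + \\log \\mathcal{N}\\left(\\vec{x}_{i}; \\, \\vec{\\mu}_{j}, \\mathbf{\\Sigma}_{j}\\right) \\right)$$\n",
    "\n",
    "The last term is a log-sum-exp, which can be evaluated stably by factoring out the largest exponent before exponentiating."
   ]
  },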
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def E_step(X, gmm):\n",
    "    pi = gmm.mixture_distribution.probs \n",
    "    components = gmm.component_distribution\n",
    "    n = X.shape[0]\n",
    "    # In practice, when dealing with exponential functions, we typically work in the log scale\n",
    "    # This is because the exponential will go to 0 or inf at finite values of x leading to \n",
    "    # overflow/ underflow. Working with logarithms leads to numerical stability\n",
    "    log_gamma_numerators = components.log_prob(X.unsqueeze(1)) + torch.log(pi).repeat(n, 1)\n",
    "    log_gamma_denominators = torch.logsumexp(log_gamma_numerators, dim=1, keepdim=True)\n",
    "    log_gamma = log_gamma_numerators - log_gamma_denominators\n",
    "    gamma = torch.exp(log_gamma)\n",
    "    return gamma"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### M Step\n",
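    "In the M step, we compute the new values of the parameters $\\pi$, $\\mu$ and $\\Sigma$ from the posteriors $\\gamma_{ik}$. The covariance update uses the newly computed mean $\\vec{\\mu}_{k}$.\n",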
    "\n",
    "$$ N_{k} = \\sum_{i=1}^{n}  \\gamma_{ik} $$\n",
    "$$\\pi_{k} = \\frac{N_{k}}{n} $$\n",
    "$$\\vec{\\mu}_{k}  = \\frac{1}{N_{k}} \\sum_{i=1}^{n}  \\gamma_{ik} \\vec{x}_{i}$$\n",
    "$${\\Sigma}_{k} =  \\frac{1}{N_{k}} \\sum_{i=1}^{n}  \\gamma_{ik} \\left( \\vec{x}_{i} - \\vec{\\mu}_{k} \\right)  \\left( \\vec{x}_{i} - \\vec{\\mu}_{k} \\right)^{T}$$"
   ]
  },
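  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick consistency check on the update for $\\pi_{k}$, the short derivation below is a sketch that assumes only that the posteriors of each point sum to one over the components, $\\sum_{k} \\gamma_{ik} = 1$, which follows from the normalization in the E step.\n",
    "\n",
    "$$\\sum_{k=1}^{K} \\pi_{k} = \\frac{1}{n} \\sum_{k=1}^{K} \\sum_{i=1}^{n} \\gamma_{ik} = \\frac{1}{n} \\sum_{i=1}^{n} 1 = 1$$\n",
    "\n",
    "So the updated weights remain a valid categorical distribution."
   ]
  },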
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def M_step(X, gmm, gamma):\n",
    "    pi_old = gmm.mixture_distribution.probs\n",
    "    components_old = gmm.component_distribution\n",
    "    \n",
    "    K = components_old.batch_shape[0]    #number of mixture components\n",
    "    d = components_old.event_shape[0]    #data dimensionality\n",
    "    n = X.shape[0]  \n",
    "    N = torch.sum(gamma, 0)\n",
    "    pi_new = N / n\n",
    "    mu_new = ((X.T @ gamma)/N).T\n",
    "    \n",
    "\n",
    "    # Use the freshly updated means, as in the M-step formula for Sigma_k\n",
    "    x_minus_mu = X.unsqueeze(0) - mu_new.unsqueeze(1)    # broadcasts to shape (K, n, d)\n",
    "    x_minus_mu_squared = x_minus_mu.unsqueeze(3) @ x_minus_mu.unsqueeze(2)\n",
    "    sigma_new =  torch.sum(gamma.T.unsqueeze(2).unsqueeze(3) * x_minus_mu_squared, axis=1) / N.unsqueeze(1).unsqueeze(1).repeat(1, d, d)\n",
    "    return create_gmm(pi_new, mu_new, sigma_new)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gmm_fit(X, gmm_init):\n",
    "    curr_gmm = gmm_init\n",
    "    eps = 1e-4\n",
    "    prev_likelihood, curr_likelihood = None, None\n",
    "    step = 0\n",
    "    while True:\n",
    "        if curr_likelihood is None:\n",
    "            curr_likelihood = torch.sum(curr_gmm.log_prob(X))\n",
    "        # E Step\n",
    "        gamma = E_step(X, curr_gmm) \n",
    "        # M Step\n",
    "        curr_gmm = M_step(X, curr_gmm, gamma)\n",
    "        prev_likelihood = curr_likelihood\n",
    "        # Compute new likelihood\n",
    "        curr_likelihood = torch.sum(curr_gmm.log_prob(X))\n",
    "        likelihood_increase = curr_likelihood - prev_likelihood\n",
    "        if step % 5 == 0:\n",
    "            plot_gmm_components(curr_gmm, X.shape[0], step)\n",
    "        step +=1\n",
    "        if likelihood_increase < eps:\n",
    "            break\n",
    "    plot_gmm_components(curr_gmm, X.shape[0], step)\n",
    "    return curr_gmm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6d4700ec48834513b9521bccf76bbc80",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7194b1225ee24863ae6255db25fa7747",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0f16387a7da744bb8efb3aa92873cb95",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9f29b9581a8e4f75b0a03cea5886c18f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "59d708df102b4c31ae8b6d944a7c6f0b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "00820ffa11df4ef8859069d72ae68ea2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fitted_gmm = gmm_fit(X, gmm_init)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitted pi: tensor([0.2100, 0.3960, 0.3940])\n",
      "Fitted means: tensor([[135.0295,  39.9497],\n",
      "        [151.9347,  55.0202],\n",
      "        [174.6390,  69.6488]])\n",
      "Fitted sigmas: tensor([[[5.3559e+00, 6.0412e-01],\n",
      "         [6.0412e-01, 5.9280e+00]],\n",
      "\n",
      "        [[8.0762e+00, 2.0353e-03],\n",
      "         [2.0353e-03, 1.4447e+01]],\n",
      "\n",
      "        [[7.3492e+00, 9.8788e+00],\n",
      "         [9.8788e+00, 2.5764e+01]]])\n"
     ]
    }
   ],
   "source": [
    "# Let us look at the learnt GMM parameters\n",
    "# Note that these are close to pi_true, mu_true and sigma_true, up to the ordering of the components\n",
    "print(\"Fitted pi: {}\".format(fitted_gmm.mixture_distribution.probs))\n",
    "print(\"Fitted means: {}\".format(fitted_gmm.component_distribution.mean))\n",
    "print(\"Fitted sigmas: {}\".format(fitted_gmm.component_distribution.covariance_matrix))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classification\n",
    "\n",
    "Given the parameters of the GMM modelling the data, we can predict the class (component) of new data points. To do so, we evaluate the conditional probability of each class given the input and choose the class with the highest value. This conditional probability effectively measures the closeness of the input to the $k^{th}$ cluster, relative to the other clusters.\n",
    "\n",
    "The conditional probability is defined as follows:\n",
    "$$ \n",
    "p\\left(Z = k \\middle\\vert X = \\vec{x}\\right) = \\frac{p\\left(\\vec{x}, k\\right) }{ \\sum_{j=1}^{K} p\\left(\\vec{x}, j\\right) } = \\frac{ \\pi_{k} \\mathcal{N}\\left(\\vec{x}; \\, \\vec{\\mu_{k}}, \\mathbf{\\Sigma}_{k}\\right) }{\\sum_{j=1}^{K} \\pi_{j} \\mathcal{N}\\left(\\vec{x}; \\, \\vec{\\mu_{j}}, \\mathbf{\\Sigma}_{j}\\right) }\n",
    "$$\n"
   ]
  },
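  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The selection rule described above can be written compactly. This is a sketch; the symbol $\\hat{k}$ for the predicted class is notation assumed here and is not used elsewhere in this notebook.\n",
    "\n",
    "$$\\hat{k}\\left(\\vec{x}\\right) = \\arg\\max_{k} \\; p\\left(Z = k \\middle\\vert X = \\vec{x}\\right) = \\arg\\max_{k} \\; \\pi_{k} \\, \\mathcal{N}\\left(\\vec{x}; \\, \\vec{\\mu_{k}}, \\mathbf{\\Sigma}_{k}\\right)$$\n",
    "\n",
    "The second equality holds because the denominator of the conditional probability is the same for every class, so it does not affect which class attains the maximum."
   ]
  },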
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def classify(gmm, X):\n",
    "    gamma = E_step(X, gmm)  # Computing the conditional probability is the same as the E step\n",
    "    classes = gamma.argmax(axis=1)  # Assign each point the component with the highest posterior\n",
    "    return classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_new = torch.tensor([[135.0, 40.0], [152.0, 55.0], [175.0, 70.0]])\n",
    "assert torch.equal(classify(fitted_gmm, X_new), torch.tensor([0, 1, 2]) )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
